home *** CD-ROM | disk | FTP | other *** search
/ Clickx 96 / Clickx 96.iso / software / tools / tool / xbmc-10.1.exe / addons / script.module.pysqlite / lib / pysqlite2 / test / userfunctions.py < prev   
Encoding:
Python Source  |  2009-10-19  |  13.0 KB  |  414 lines

  1. #-*- coding: ISO-8859-1 -*-
  2. # pysqlite2/test/userfunctions.py: tests for user-defined functions and
  3. #                                  aggregates.
  4. #
  5. # Copyright (C) 2005-2007 Gerhard HΣring <gh@ghaering.de>
  6. #
  7. # This file is part of pysqlite.
  8. #
  9. # This software is provided 'as-is', without any express or implied
  10. # warranty.  In no event will the authors be held liable for any damages
  11. # arising from the use of this software.
  12. #
  13. # Permission is granted to anyone to use this software for any purpose,
  14. # including commercial applications, and to alter it and redistribute it
  15. # freely, subject to the following restrictions:
  16. #
  17. # 1. The origin of this software must not be misrepresented; you must not
  18. #    claim that you wrote the original software. If you use this software
  19. #    in a product, an acknowledgment in the product documentation would be
  20. #    appreciated but is not required.
  21. # 2. Altered source versions must be plainly marked as such, and must not be
  22. #    misrepresented as being the original software.
  23. # 3. This notice may not be removed or altered from any source distribution.
  24.  
  25. import unittest
  26. import pysqlite2.dbapi2 as sqlite
  27.  
  28. def func_returntext():
  29.     return "foo"
  30. def func_returnunicode():
  31.     return u"bar"
  32. def func_returnint():
  33.     return 42
  34. def func_returnfloat():
  35.     return 3.14
  36. def func_returnnull():
  37.     return None
  38. def func_returnblob():
  39.     return buffer("blob")
  40. def func_raiseexception():
  41.     5/0
  42.  
  43. def func_isstring(v):
  44.     return type(v) is unicode
  45. def func_isint(v):
  46.     return type(v) is int
  47. def func_isfloat(v):
  48.     return type(v) is float
  49. def func_isnone(v):
  50.     return type(v) is type(None)
  51. def func_isblob(v):
  52.     return type(v) is buffer
  53.  
  54. class AggrNoStep:
  55.     def __init__(self):
  56.         pass
  57.  
  58.     def finalize(self):
  59.         return 1
  60.  
  61. class AggrNoFinalize:
  62.     def __init__(self):
  63.         pass
  64.  
  65.     def step(self, x):
  66.         pass
  67.  
  68. class AggrExceptionInInit:
  69.     def __init__(self):
  70.         5/0
  71.  
  72.     def step(self, x):
  73.         pass
  74.  
  75.     def finalize(self):
  76.         pass
  77.  
  78. class AggrExceptionInStep:
  79.     def __init__(self):
  80.         pass
  81.  
  82.     def step(self, x):
  83.         5/0
  84.  
  85.     def finalize(self):
  86.         return 42
  87.  
  88. class AggrExceptionInFinalize:
  89.     def __init__(self):
  90.         pass
  91.  
  92.     def step(self, x):
  93.         pass
  94.  
  95.     def finalize(self):
  96.         5/0
  97.  
  98. class AggrCheckType:
  99.     def __init__(self):
  100.         self.val = None
  101.  
  102.     def step(self, whichType, val):
  103.         theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
  104.         self.val = int(theType[whichType] is type(val))
  105.  
  106.     def finalize(self):
  107.         return self.val
  108.  
  109. class AggrSum:
  110.     def __init__(self):
  111.         self.val = 0.0
  112.  
  113.     def step(self, val):
  114.         self.val += val
  115.  
  116.     def finalize(self):
  117.         return self.val
  118.  
  119. class FunctionTests(unittest.TestCase):
  120.     def setUp(self):
  121.         self.con = sqlite.connect(":memory:")
  122.  
  123.         self.con.create_function("returntext", 0, func_returntext)
  124.         self.con.create_function("returnunicode", 0, func_returnunicode)
  125.         self.con.create_function("returnint", 0, func_returnint)
  126.         self.con.create_function("returnfloat", 0, func_returnfloat)
  127.         self.con.create_function("returnnull", 0, func_returnnull)
  128.         self.con.create_function("returnblob", 0, func_returnblob)
  129.         self.con.create_function("raiseexception", 0, func_raiseexception)
  130.  
  131.         self.con.create_function("isstring", 1, func_isstring)
  132.         self.con.create_function("isint", 1, func_isint)
  133.         self.con.create_function("isfloat", 1, func_isfloat)
  134.         self.con.create_function("isnone", 1, func_isnone)
  135.         self.con.create_function("isblob", 1, func_isblob)
  136.  
  137.     def tearDown(self):
  138.         self.con.close()
  139.  
  140.     def CheckFuncErrorOnCreate(self):
  141.         try:
  142.             self.con.create_function("bla", -100, lambda x: 2*x)
  143.             self.fail("should have raised an OperationalError")
  144.         except sqlite.OperationalError:
  145.             pass
  146.  
  147.     def CheckFuncRefCount(self):
  148.         def getfunc():
  149.             def f():
  150.                 return 1
  151.             return f
  152.         f = getfunc()
  153.         globals()["foo"] = f
  154.         # self.con.create_function("reftest", 0, getfunc())
  155.         self.con.create_function("reftest", 0, f)
  156.         cur = self.con.cursor()
  157.         cur.execute("select reftest()")
  158.  
  159.     def CheckFuncReturnText(self):
  160.         cur = self.con.cursor()
  161.         cur.execute("select returntext()")
  162.         val = cur.fetchone()[0]
  163.         self.failUnlessEqual(type(val), unicode)
  164.         self.failUnlessEqual(val, "foo")
  165.  
  166.     def CheckFuncReturnUnicode(self):
  167.         cur = self.con.cursor()
  168.         cur.execute("select returnunicode()")
  169.         val = cur.fetchone()[0]
  170.         self.failUnlessEqual(type(val), unicode)
  171.         self.failUnlessEqual(val, u"bar")
  172.  
  173.     def CheckFuncReturnInt(self):
  174.         cur = self.con.cursor()
  175.         cur.execute("select returnint()")
  176.         val = cur.fetchone()[0]
  177.         self.failUnlessEqual(type(val), int)
  178.         self.failUnlessEqual(val, 42)
  179.  
  180.     def CheckFuncReturnFloat(self):
  181.         cur = self.con.cursor()
  182.         cur.execute("select returnfloat()")
  183.         val = cur.fetchone()[0]
  184.         self.failUnlessEqual(type(val), float)
  185.         if val < 3.139 or val > 3.141:
  186.             self.fail("wrong value")
  187.  
  188.     def CheckFuncReturnNull(self):
  189.         cur = self.con.cursor()
  190.         cur.execute("select returnnull()")
  191.         val = cur.fetchone()[0]
  192.         self.failUnlessEqual(type(val), type(None))
  193.         self.failUnlessEqual(val, None)
  194.  
  195.     def CheckFuncReturnBlob(self):
  196.         cur = self.con.cursor()
  197.         cur.execute("select returnblob()")
  198.         val = cur.fetchone()[0]
  199.         self.failUnlessEqual(type(val), buffer)
  200.         self.failUnlessEqual(val, buffer("blob"))
  201.  
  202.     def CheckFuncException(self):
  203.         cur = self.con.cursor()
  204.         try:
  205.             cur.execute("select raiseexception()")
  206.             cur.fetchone()
  207.             self.fail("should have raised OperationalError")
  208.         except sqlite.OperationalError, e:
  209.             self.failUnlessEqual(e.args[0], 'user-defined function raised exception')
  210.  
  211.     def CheckParamString(self):
  212.         cur = self.con.cursor()
  213.         cur.execute("select isstring(?)", ("foo",))
  214.         val = cur.fetchone()[0]
  215.         self.failUnlessEqual(val, 1)
  216.  
  217.     def CheckParamInt(self):
  218.         cur = self.con.cursor()
  219.         cur.execute("select isint(?)", (42,))
  220.         val = cur.fetchone()[0]
  221.         self.failUnlessEqual(val, 1)
  222.  
  223.     def CheckParamFloat(self):
  224.         cur = self.con.cursor()
  225.         cur.execute("select isfloat(?)", (3.14,))
  226.         val = cur.fetchone()[0]
  227.         self.failUnlessEqual(val, 1)
  228.  
  229.     def CheckParamNone(self):
  230.         cur = self.con.cursor()
  231.         cur.execute("select isnone(?)", (None,))
  232.         val = cur.fetchone()[0]
  233.         self.failUnlessEqual(val, 1)
  234.  
  235.     def CheckParamBlob(self):
  236.         cur = self.con.cursor()
  237.         cur.execute("select isblob(?)", (buffer("blob"),))
  238.         val = cur.fetchone()[0]
  239.         self.failUnlessEqual(val, 1)
  240.  
  241. class AggregateTests(unittest.TestCase):
  242.     def setUp(self):
  243.         self.con = sqlite.connect(":memory:")
  244.         cur = self.con.cursor()
  245.         cur.execute("""
  246.             create table test(
  247.                 t text,
  248.                 i integer,
  249.                 f float,
  250.                 n,
  251.                 b blob
  252.                 )
  253.             """)
  254.         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
  255.             ("foo", 5, 3.14, None, buffer("blob"),))
  256.  
  257.         self.con.create_aggregate("nostep", 1, AggrNoStep)
  258.         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
  259.         self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
  260.         self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
  261.         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
  262.         self.con.create_aggregate("checkType", 2, AggrCheckType)
  263.         self.con.create_aggregate("mysum", 1, AggrSum)
  264.  
  265.     def tearDown(self):
  266.         #self.cur.close()
  267.         #self.con.close()
  268.         pass
  269.  
  270.     def CheckAggrErrorOnCreate(self):
  271.         try:
  272.             self.con.create_function("bla", -100, AggrSum)
  273.             self.fail("should have raised an OperationalError")
  274.         except sqlite.OperationalError:
  275.             pass
  276.  
  277.     def CheckAggrNoStep(self):
  278.         cur = self.con.cursor()
  279.         try:
  280.             cur.execute("select nostep(t) from test")
  281.             self.fail("should have raised an AttributeError")
  282.         except AttributeError, e:
  283.             self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
  284.  
  285.     def CheckAggrNoFinalize(self):
  286.         cur = self.con.cursor()
  287.         try:
  288.             cur.execute("select nofinalize(t) from test")
  289.             val = cur.fetchone()[0]
  290.             self.fail("should have raised an OperationalError")
  291.         except sqlite.OperationalError, e:
  292.             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
  293.  
  294.     def CheckAggrExceptionInInit(self):
  295.         cur = self.con.cursor()
  296.         try:
  297.             cur.execute("select excInit(t) from test")
  298.             val = cur.fetchone()[0]
  299.             self.fail("should have raised an OperationalError")
  300.         except sqlite.OperationalError, e:
  301.             self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
  302.  
  303.     def CheckAggrExceptionInStep(self):
  304.         cur = self.con.cursor()
  305.         try:
  306.             cur.execute("select excStep(t) from test")
  307.             val = cur.fetchone()[0]
  308.             self.fail("should have raised an OperationalError")
  309.         except sqlite.OperationalError, e:
  310.             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
  311.  
  312.     def CheckAggrExceptionInFinalize(self):
  313.         cur = self.con.cursor()
  314.         try:
  315.             cur.execute("select excFinalize(t) from test")
  316.             val = cur.fetchone()[0]
  317.             self.fail("should have raised an OperationalError")
  318.         except sqlite.OperationalError, e:
  319.             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
  320.  
  321.     def CheckAggrCheckParamStr(self):
  322.         cur = self.con.cursor()
  323.         cur.execute("select checkType('str', ?)", ("foo",))
  324.         val = cur.fetchone()[0]
  325.         self.failUnlessEqual(val, 1)
  326.  
  327.     def CheckAggrCheckParamInt(self):
  328.         cur = self.con.cursor()
  329.         cur.execute("select checkType('int', ?)", (42,))
  330.         val = cur.fetchone()[0]
  331.         self.failUnlessEqual(val, 1)
  332.  
  333.     def CheckAggrCheckParamFloat(self):
  334.         cur = self.con.cursor()
  335.         cur.execute("select checkType('float', ?)", (3.14,))
  336.         val = cur.fetchone()[0]
  337.         self.failUnlessEqual(val, 1)
  338.  
  339.     def CheckAggrCheckParamNone(self):
  340.         cur = self.con.cursor()
  341.         cur.execute("select checkType('None', ?)", (None,))
  342.         val = cur.fetchone()[0]
  343.         self.failUnlessEqual(val, 1)
  344.  
  345.     def CheckAggrCheckParamBlob(self):
  346.         cur = self.con.cursor()
  347.         cur.execute("select checkType('blob', ?)", (buffer("blob"),))
  348.         val = cur.fetchone()[0]
  349.         self.failUnlessEqual(val, 1)
  350.  
  351.     def CheckAggrCheckAggrSum(self):
  352.         cur = self.con.cursor()
  353.         cur.execute("delete from test")
  354.         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
  355.         cur.execute("select mysum(i) from test")
  356.         val = cur.fetchone()[0]
  357.         self.failUnlessEqual(val, 60)
  358.  
  359. def authorizer_cb(action, arg1, arg2, dbname, source):
  360.     if action != sqlite.SQLITE_SELECT:
  361.         return sqlite.SQLITE_DENY
  362.     if arg2 == 'c2' or arg1 == 't2':
  363.         return sqlite.SQLITE_DENY
  364.     return sqlite.SQLITE_OK
  365.  
  366. class AuthorizerTests(unittest.TestCase):
  367.     def setUp(self):
  368.         self.con = sqlite.connect(":memory:")
  369.         self.con.executescript("""
  370.             create table t1 (c1, c2);
  371.             create table t2 (c1, c2);
  372.             insert into t1 (c1, c2) values (1, 2);
  373.             insert into t2 (c1, c2) values (4, 5);
  374.             """)
  375.  
  376.         # For our security test:
  377.         self.con.execute("select c2 from t2")
  378.  
  379.         self.con.set_authorizer(authorizer_cb)
  380.  
  381.     def tearDown(self):
  382.         pass
  383.  
  384.     def CheckTableAccess(self):
  385.         try:
  386.             self.con.execute("select * from t2")
  387.         except sqlite.DatabaseError, e:
  388.             if not e.args[0].endswith("prohibited"):
  389.                 self.fail("wrong exception text: %s" % e.args[0])
  390.             return
  391.         self.fail("should have raised an exception due to missing privileges")
  392.  
  393.     def CheckColumnAccess(self):
  394.         try:
  395.             self.con.execute("select c2 from t1")
  396.         except sqlite.DatabaseError, e:
  397.             if not e.args[0].endswith("prohibited"):
  398.                 self.fail("wrong exception text: %s" % e.args[0])
  399.             return
  400.         self.fail("should have raised an exception due to missing privileges")
  401.  
  402. def suite():
  403.     function_suite = unittest.makeSuite(FunctionTests, "Check")
  404.     aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
  405.     authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check")
  406.     return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite))
  407.  
  408. def test():
  409.     runner = unittest.TextTestRunner()
  410.     runner.run(suite())
  411.  
  412. if __name__ == "__main__":
  413.     test()
  414.